# -*- coding: utf-8 -*-
from warnings import warn

import numpy as np
import pandas as pd

from ..misc import NeuroKitWarning
from .eda_autocor import eda_autocor
from .eda_sympathetic import eda_sympathetic


def eda_intervalrelated(data, sampling_rate=1000, **kwargs):
    """**EDA Analysis on Interval-Related Data**

    Performs EDA analysis on longer periods of data (typically > 10 seconds), such as resting-state
    data.

    Parameters
    ----------

    data : Union[dict, pd.DataFrame]
        A DataFrame containing the different processed signal(s) as different columns, typically
        generated by :func:`eda_process` or :func:`bio_process`. Can also take a dict containing
        sets of separately processed DataFrames (each of which must have a ``"Label"`` column).
    sampling_rate : int
        The sampling frequency of the EDA signal (in Hz, i.e., samples/second). Defaults to 1000.
    **kwargs
        Other arguments to be passed to the functions.

    Returns
    -------
    DataFrame
        A dataframe containing the analyzed EDA features (one row for a single DataFrame input,
        one row per key for a dict input). The analyzed features consist of the following:

        .. codebookadd::
            SCR_Peaks_N|The number of occurrences of Skin Conductance Response (SCR).
            SCR_Peaks_Amplitude_Mean|The mean amplitude of the SCR peak occurrences.
            EDA_Tonic_SD|The standard deviation of the tonic component of the EDA signal.

        * ``"EDA_Sympathetic"`` and ``"EDA_SympatheticN"``: see :func:`eda_sympathetic` (only
          computed if signal duration > 64 sec).
        * ``"EDA_Autocorrelation"``: see :func:`eda_autocor` (only computed if signal duration
          > 30 sec).

    Raises
    ------
    TypeError
        If ``data`` is neither a DataFrame nor a dict.

    See Also
    --------
    .bio_process, eda_eventrelated

    Examples
    --------
    .. ipython:: python

        import neurokit2 as nk

        # Download data
        data = nk.data("bio_resting_8min_100hz")

        # Process the data
        df, info = nk.eda_process(data["EDA"], sampling_rate=100)

        # Single dataframe is passed
        nk.eda_intervalrelated(df, sampling_rate=100)

        epochs = nk.epochs_create(df, events=[0, 25300], sampling_rate=100, epochs_end=20)
        nk.eda_intervalrelated(epochs, sampling_rate=100)

    """
    # Single recording: one row of features
    if isinstance(data, pd.DataFrame):
        # Always hand the helper a fresh dict so features never leak between calls
        features = _eda_intervalrelated(data, output={}, sampling_rate=sampling_rate, **kwargs)
        return pd.DataFrame.from_dict(features, orient="index").T

    # Collection of recordings (e.g., epochs): one row of features per key
    if isinstance(data, dict):
        results = {}
        for index, epoch in data.items():
            # Seed the row with the label before the features are appended
            row = {"Label": epoch["Label"].iloc[0]}
            results[index] = _eda_intervalrelated(epoch, output=row, sampling_rate=sampling_rate, **kwargs)
        return pd.DataFrame.from_dict(results, orient="index")

    # Anything else used to fall through to an unbound `results`; fail explicitly instead
    raise TypeError(
        "NeuroKit error: eda_intervalrelated(): `data` must be a pandas DataFrame or a dict of "
        f"DataFrames, but got {type(data).__name__}."
    )


# =============================================================================
# Internals
# =============================================================================


def _eda_intervalrelated(data, output={}, sampling_rate=1000, method_sympathetic="posada", **kwargs):
    """Format input for dictionary."""
    # Sanitize input
    colnames = data.columns.values

    # SCR Peaks
    if "SCR_Peaks" not in colnames:
        warn(
            "We couldn't find an `SCR_Peaks` column. Returning NaN for N peaks.",
            category=NeuroKitWarning,
        )
        output["SCR_Peaks_N"] = np.nan
    else:
        output["SCR_Peaks_N"] = np.nansum(data["SCR_Peaks"].values)

    # Peak amplitude
    if "SCR_Amplitude" not in colnames:
        warn(
            "We couldn't find an `SCR_Amplitude` column. Returning NaN for peak amplitude.",
            category=NeuroKitWarning,
        )
        output["SCR_Peaks_Amplitude_Mean"] = np.nan
    else:
        peaks_idx = data["SCR_Peaks"] == 1
        # Mean amplitude is only computed over peaks. If no peaks, return NaN
        if peaks_idx.sum() > 0:
            output["SCR_Peaks_Amplitude_Mean"] = np.nanmean(data[peaks_idx]["SCR_Amplitude"].values)
        else:
            output["SCR_Peaks_Amplitude_Mean"] = np.nan

    # Get variability of tonic
    if "EDA_Tonic" in colnames:
        output["EDA_Tonic_SD"] = np.nanstd(data["EDA_Tonic"].values)

    # EDA Sympathetic
    output.update({"EDA_Sympathetic": np.nan, "EDA_SympatheticN": np.nan})  # Default values
    if len(data) > sampling_rate * 64:
        if "EDA_Clean" in colnames:
            output.update(
                eda_sympathetic(
                    data["EDA_Clean"],
                    sampling_rate=sampling_rate,
                    method=method_sympathetic,
                )
            )
        elif "EDA_Raw" in colnames:
            # If not clean signal, use raw
            output.update(
                eda_sympathetic(
                    data["EDA_Raw"],
                    sampling_rate=sampling_rate,
                    method=method_sympathetic,
                )
            )

    # EDA autocorrelation
    output.update({"EDA_Autocorrelation": np.nan})  # Default values
    if len(data) > sampling_rate * 30:  # 30 seconds minimum (NOTE: somewhat arbitrary)
        if "EDA_Clean" in colnames:
            output["EDA_Autocorrelation"] = eda_autocor(data["EDA_Clean"], sampling_rate=sampling_rate, **kwargs)
        elif "EDA_Raw" in colnames:
            # If not clean signal, use raw
            output["EDA_Autocorrelation"] = eda_autocor(data["EDA_Raw"], sampling_rate=sampling_rate, **kwargs)

    return output
